import torch
import time
from logging import getLogger
import time
from torch.utils.data import DataLoader
import pickle
from SWTWTEnv import PFSPEnv as Env
from SMTWTModel import SMTWTModel as Model
from utils import get_result_folder, AverageMeter, TimeEstimator
from torch.utils.data import DataLoader

# Module-level default device: CUDA when it is available, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

class SMTWTTester:
    """Evaluate a pre-trained model on a pickled test set.

    On construction the tester selects the compute device, builds the
    environment and the model from the given parameter dicts, restores the
    model weights from a checkpoint and wraps ``./data/test{job_cnt}.pkl`` in
    a ``DataLoader``.  :meth:`run` then rolls the model out on every batch and
    logs the averaged objective value.
    """

    def __init__(self,
                 env_params,
                 model_params,
                 tester_params):
        """Build the environment, model and test data loader.

        Note: this constructor changes the process-wide default tensor type
        (and, with CUDA, the current CUDA device).

        Args:
            env_params: Keyword arguments forwarded to the environment.
                ``'job_cnt'`` and ``'pomo_size'`` are also read here.
            model_params: Keyword arguments forwarded to the model.
                ``'latent_cont_size'`` and ``'latent_disc_size'`` are also
                read here.
            tester_params: Test settings. Keys read: ``'test_batch_size'``,
                ``'use_cuda'``, ``'cuda_device_num'`` (only with CUDA),
                ``'model_load'`` (a dict with ``'path'`` and ``'epoch'``) and,
                inside :meth:`run`, ``'epochs'``.
        """
        # save arguments
        self.env_params = env_params
        self.model_params = model_params
        self.tester_params = tester_params

        # result folder, logger
        self.logger = getLogger(name='trainer')
        self.result_folder = get_result_folder()

        self.n_jobs = self.env_params['job_cnt']
        self.pomo_size = self.env_params['pomo_size']
        self.latent_cont_dim = self.model_params['latent_cont_size']
        self.latent_disc_dim = self.model_params['latent_disc_size']
        self.test_batch_size = self.tester_params['test_batch_size']

        # cuda
        USE_CUDA = self.tester_params['use_cuda']
        if USE_CUDA:
            cuda_device_num = self.tester_params['cuda_device_num']
            torch.cuda.set_device(cuda_device_num)
            device = torch.device('cuda', cuda_device_num)
            torch.set_default_tensor_type('torch.cuda.FloatTensor')
        else:
            device = torch.device('cpu')
            torch.set_default_tensor_type('torch.FloatTensor')
        self.device = device

        # ENV and MODEL
        self.env = Env(**self.env_params)
        self.model = Model(**self.model_params)

        # NOTE: torch.load unpickles the checkpoint; only load trusted files.
        model_load = self.tester_params['model_load']
        checkpoint_fullname = '{path}/checkpoint-{epoch}.pt'.format(**model_load)
        checkpoint = torch.load(checkpoint_fullname, map_location=device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.logger.info('...Load Pre-trained model...')

        # utility
        self.time_estimator = TimeEstimator()

        # NOTE(security): pickle.load executes arbitrary code embedded in the
        # file; the test-set pickle must come from a trusted source.
        with open(f'./data/test{self.n_jobs}.pkl', 'rb') as f:
            loaded_list = pickle.load(f)
        # The unpickled object must provide ``.to(device)`` and be indexable
        # by the DataLoader (presumably a tensor of instances -- confirm).
        test_dataset = loaded_list.to(device)
        self.test_dataloader = DataLoader(
            test_dataset,
            batch_size=self.test_batch_size,
            shuffle=False,
            generator=torch.Generator(device=device),
        )

    @staticmethod
    def _sample_latent(batch_size, pomo_size, cont_dim, disc_dim, device=None):
        """Draw one random latent vector per (instance, rollout) pair.

        The continuous part is sampled uniformly from [-1, 1); the discrete
        part is a one-hot vector with a uniformly random active index.

        Args:
            batch_size: Number of problem instances in the batch.
            pomo_size: Number of parallel rollouts per instance.
            cont_dim: Size of the continuous latent part.
            disc_dim: Size of the one-hot (discrete) latent part.
            device: Device for the new tensors; ``None`` uses the default.

        Returns:
            Float tensor of shape ``(batch_size, pomo_size, disc_dim +
            cont_dim)`` with the one-hot part first.
        """
        # Continuous part first, then the discrete index, so the order of
        # random draws stays the same for a fixed seed.
        latent_c_var = torch.empty(
            batch_size, pomo_size, cont_dim, device=device).uniform_(-1, 1)

        one_hot_idx = torch.randint(
            0, disc_dim, (batch_size, pomo_size),
            dtype=torch.long, device=device)
        latent_d_var = torch.zeros(
            (batch_size, pomo_size, disc_dim),
            dtype=torch.float32, device=device)
        latent_d_var.scatter_(-1, one_hot_idx.unsqueeze(-1), 1.0)

        return torch.cat([latent_d_var, latent_c_var], dim=-1)

    @staticmethod
    def _score_rewards(reward, pomo_size, aug_factor=1):
        """Reduce rollout rewards to (no-augmentation, augmentation) scores.

        Args:
            reward: Tensor holding ``aug_factor * batch * pomo_size`` rewards
                (reshaped to ``(aug_factor, batch, pomo_size)``).
            pomo_size: Number of rollouts per instance.
            aug_factor: Number of augmented copies per instance.

        Returns:
            Tuple ``(no_aug_score, aug_score)`` of Python floats; both are
            sign-flipped means of the best reward per instance.
        """
        aug_reward = reward.reshape(aug_factor, -1, pomo_size)

        max_pomo_reward, _ = aug_reward.max(dim=2)  # get best results from pomo
        # shape: (augmentation, batch)
        # negative sign to make positive value
        no_aug_score = -max_pomo_reward[0, :].float().mean()
        max_aug_pomo_reward, _ = max_pomo_reward.max(dim=0)
        aug_score = -max_aug_pomo_reward.float().mean()
        return no_aug_score.item(), aug_score.item()

    def _rollout(self, problems_batched, latent_var, batch_size):
        """Run one full episode on a batch and return the final reward."""
        with torch.no_grad():
            self.env.load_problems_manual(problems_batched)
            reset_state, _, _ = self.env.reset()
            # Starts empty along dim 2; one selected index is appended per step.
            selected_list = torch.zeros(
                size=(batch_size, self.env.pomo_size, 0),
                dtype=torch.long, device=self.device)
            self.model.pre_forward(reset_state, latent_var)

            state, reward, done = self.env.pre_step()
            while not done:
                selected, _ = self.model(state, selected_list)

                state, reward, done = self.env.step(selected)
                selected_list = torch.cat(
                    (selected_list, selected[:, :, None]), dim=2)
        return reward

    def run(self):
        """Evaluate the model on every test batch and log the mean score.

        Each batch gets freshly sampled latent variables, is rolled out once
        and contributes the best reward per instance to a running average
        weighted by the number of instances actually in that batch.
        """
        self.time_estimator.reset()
        score_AM = AverageMeter()
        no_aug_AM = AverageMeter()
        aug_factor = 1  # this tester applies no instance augmentation
        pomo_size = self.env.pomo_size

        self.model.eval()
        inference_start_t = time.time()
        for epoch, problems_batched in enumerate(self.test_dataloader, start=1):
            # Use the size of the batch actually yielded: the last batch is
            # smaller than test_batch_size when the dataset size is not a
            # multiple of it.  Assumes samples are stacked along dim 0.
            batch_size = len(problems_batched)

            latent_var = self._sample_latent(
                batch_size, pomo_size,
                self.latent_cont_dim, self.latent_disc_dim,
                device=self.device)
            reward = self._rollout(problems_batched, latent_var, batch_size)
            no_aug_score, aug_score = self._score_rewards(
                reward, pomo_size, aug_factor)

            score_AM.update(aug_score, batch_size)
            no_aug_AM.update(no_aug_score, batch_size)

            # NOTE(review): progress is reported against
            # tester_params['epochs'] while the loop iterates over data-loader
            # batches -- confirm the two agree.
            elapsed_time_str, remain_time_str = self.time_estimator.get_est_string(epoch, self.tester_params['epochs'])
            self.logger.info("episode {:3d}/{:3d}, Elapsed[{}], Remain[{}]".format(epoch, self.tester_params['epochs'], elapsed_time_str, remain_time_str))

        self.logger.info(" *** Test Done *** ")
        self.logger.info(" Inference Time(s): {:.4f}s".format(time.time()-inference_start_t))
        self.logger.info(" *** Objective value *** ")
        self.logger.info("Test SCORE: {:.4f} ".format(no_aug_AM.avg))